package com.sxf.mybatis.interceptor;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.StringTokenizer;

import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import com.sxf.mybatis.dialect.Dialect;
import com.sxf.mybatis.help.ExRowBounds;
import com.sxf.mybatis.help.QueryArrayList;
import com.sxf.mybatis.help.QueryType;

/**
 * 分页类
 * 
 * @author phsxf01
 * 
 */
@Intercepts({ @Signature(type = Executor.class, method = "query", args = {
		MappedStatement.class, Object.class, RowBounds.class,
		ResultHandler.class }) })
@SuppressWarnings("unchecked")
public class PaginationInterceptor implements Interceptor {

	private static Dialect dialect;
	private static final Log log = LogFactory.getLog(PreparedStatement.class);

	public Object intercept(Invocation invocation) throws Throwable {
		// Executor executor = (Executor) invocation.getTarget();
		Object result = null;
		Object rowBounds = invocation.getArgs()[2];
		if (rowBounds != null && !(rowBounds == RowBounds.DEFAULT)
				&& rowBounds instanceof ExRowBounds) {
			ExRowBounds bounds = (ExRowBounds) rowBounds;
			result = pageIntercept(invocation, bounds);
		} else {
			result = invocation.proceed();
		}
		return result;
	}

	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	public void setProperties(Properties properties) {
		String dialectClass = properties.getProperty("dialectClass");
		try {
			if (dialect == null) {
				dialect = (Dialect) Class.forName(dialectClass).newInstance();
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	private Object pageIntercept(Invocation invocation, ExRowBounds bounds)
			throws Throwable {
		long rowCount = 0;
		Executor executor = (Executor) invocation.getTarget();
		MappedStatement mappedStatement = (MappedStatement) invocation
				.getArgs()[0];
		Object parameterObject = invocation.getArgs()[1];
		ResultHandler resultHandler = (ResultHandler) invocation.getArgs()[3];
		BoundSql boundSql = mappedStatement.getBoundSql(parameterObject);
		QueryType queryType = QueryType.OBJECT;
		queryType = bounds.getQueryType();
		switch (queryType) {
		case NUMBER: {
			rowCount = getTotalCountByConnection(mappedStatement,
					parameterObject);
		}
			break;
		case OBJECT_NUMBER: {
			// 设置统计总数
			rowCount = getTotalCountByConnection(mappedStatement,
					parameterObject);
			// 设置分页
			MetaObject metaObject = SystemMetaObject.forObject(boundSql);
			metaObject.setValue("sql",
					getPageLimitSql(boundSql.getSql(), bounds));
		}
			break;
		// 默认查询QUERY_OBJECT
		default: {
			// 设置分页
			MetaObject metaObject = SystemMetaObject.forObject(boundSql);
			metaObject.setValue("sql",
					getPageLimitSql(boundSql.getSql(), bounds));
		}
		}
		Object result = null;
		if (queryType != QueryType.NUMBER) {
			// 代替Executor的query方法
			CacheKey cacheKey = executor.createCacheKey(mappedStatement,
					parameterObject, RowBounds.DEFAULT, boundSql);
			result = executor.query(mappedStatement, parameterObject,
					RowBounds.DEFAULT, resultHandler, cacheKey, boundSql);
			// obj = invocation.proceed();
		}
		List<Object> objList = new ArrayList<Object>();// 空list会被替换,或者只有一个元素
		if (result != null) {
			if (result instanceof List) {
				objList = (List<Object>) result;
			} else {
				objList.add(result);
			}
		}
		result = new QueryArrayList<Object>(objList, rowCount);
		return result;
	}

	/**
	 * @param mappedStatement
	 * @param parameterObject
	 * @return
	 * @throws Exception
	 */
	private long getTotalCountByConnection(MappedStatement mappedStatement,
			Object parameterObject) throws Exception {
		long rowCount = 0;
		BoundSql boundSql = mappedStatement.getBoundSql(parameterObject);
		String sql = getCountSql(boundSql.getSql());
		if (log.isDebugEnabled()) {
			log.debug("==>>>  Executing statistics total: "
					+ removeBreakingWhitespace(sql));
		}
		// MappedStatement newMappedStatement =
		// mappedStatement.getConfiguration()
		// .getMappedStatement(mappedStatement.getId());
		Connection connection = mappedStatement.getConfiguration()
				.getEnvironment().getDataSource().getConnection();
		DefaultParameterHandler dpHandler = new DefaultParameterHandler(
				mappedStatement, parameterObject, boundSql);
		PreparedStatement countStmt = connection.prepareStatement(sql);
		dpHandler.setParameters(countStmt);
		ResultSet rs = countStmt.executeQuery();
		if (rs.next()) {
			rowCount = rs.getLong(1);
		}
		rs.close();
		countStmt.close();
		connection.close();
		return rowCount;
	}

	private String getCountSql(String sql) throws Exception {
		StringBuilder sb = new StringBuilder();
		sb.append("select count(1) from ( ");
		sb.append(sql);
		sb.append(" ) t");
		String countSql = sb.toString();
		return countSql;
	}

	/**
	 * 分页语句
	 * 
	 * @param sql
	 * @param bounds
	 * @return
	 * @throws Exception
	 */
	private String getPageLimitSql(String sql, ExRowBounds bounds)
			throws Exception {
		String newSql = null;
		if (bounds.getLimit() > 0 && bounds.getLimit() < RowBounds.NO_ROW_LIMIT) {
			newSql = dialect.getLimitString(sql, bounds.getOffset(),
					bounds.getLimit());
		} else {
			newSql = sql;
		}
		return newSql;
	}

	private String removeBreakingWhitespace(String original) {
		StringTokenizer whitespaceStripper = new StringTokenizer(original);
		StringBuilder builder = new StringBuilder();
		while (whitespaceStripper.hasMoreTokens()) {
			builder.append(whitespaceStripper.nextToken());
			builder.append(" ");
		}
		return builder.toString();
	}
}